// Copyright 2025 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Implementation of generating multitarget modules according to the
// *SPV_INTEL_function_variants* extension
//
// A multitarget module is generated by linking separate modules: a base module
// and variant modules containing device-specific variants of the functions in
// the base module. The behavior is controlled by Comma-Separated Values (CSV)
// files passed to the following flags:
// --fnvar-targets: Required columns:
//   module   - module file name
//   target   - device target ISA value
//   features - feature values for the target separated by '/' (FEAT_SEP)
// --fnvar-architectures: Required columns:
//   module       - module file name
//   category     - device category value
//   family       - device family value
//   op           - opcode of the comparison instruction
//   architecture - device architecture
// The values (except module) are decimal strings with their meaning defined in
// the 'targets registry' as described in the extension spec. The decimal
// strings may only encode unsigned 32-bit integers (characters 0-9), possibly
// with leading zeros.
//
// In addition, --fnvar-capabilities generates OpSpecConstantCapabilitiesINTEL
// for each module with operands corresponding to the module's capabilities.
//
// Each line in the targets/architectures CSV file defines one
// OpSpecConstant<Target/Architecture>INTEL instruction, the columns correspond
// to the operands of these instructions. One module can have multiple lines, in
// which case they are combined into a single boolean spec constant using
// OpSpecConstantOp and OpLogicalOr (except when category and family in the
// architectures CSV are the same, then the lines are combined with
// OpLogicalAnd). For example, the following architectures CSV
//
//     module,category,family,op,architecture
//     foo.spv,1,7,174,1
//     foo.spv,1,7,178,3
//     foo.spv,1,8,170,1
//
// is combined as follows:
//
//          %53 = OpSpecConstantArchitectureINTEL %bool 1 7 174 1
//          %54 = OpSpecConstantArchitectureINTEL %bool 1 7 178 3
//          %55 = OpSpecConstantArchitectureINTEL %bool 1 8 170 1
//          %56 = OpSpecConstantOp %bool LogicalAnd %53 %54
//     %foo_spv = OpSpecConstantOp %bool LogicalOr %55 %56
//
// The %foo_spv is annotated with OpName "foo.spv" (the module's name) which
// serves as an identifier to find the constant later. We cannot use IDs for it
// because the IDs get shifted during linking.
//
// The first module passed to `spirv-link` is considered the 'base' module. For
// example, if the base module defines functions 'foo' and 'bar' and the other
// modules define only 'foo', only 'foo' is treated as a function variant
// guarded by spec constants. The 'bar' function will be untouched and therefore
// present for all variants. The function variants are matched by name, and
// therefore they must either have an entry point, or an Export linkage
// attribute.

#ifndef FNVAR_H
#define FNVAR_H

#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "source/opt/ir_context.h"
#include "source/opt/module.h"
#include "spirv-tools/linker.hpp"

namespace spvtools {

// Make the optimizer IR types used by the declarations below available in the
// spvtools namespace without the opt:: qualifier.
using opt::IRContext;
using opt::Module;

// Map of instruction hash -> which variants are using the instruction (denoted
// by the index to the variants vector)
//
// Key:   hash of an instruction (size_t).
// Value: indices of the variants that use that instruction.
// NOTE(review): the hash function that produces the keys is not visible in
// this header; see the implementation file.
using FnVarUsage = std::unordered_map<size_t, std::vector<size_t>>;

// Map of base function call ID -> variant functions corresponding to the
// called function (along with the variant name)
//
// Key:   32-bit ID identifying a function call in the base module.
// Value: list of (variant name, variant function) pairs; the function pointers
//        are const and non-owning.
// NOTE(review): whether the key is the result ID of the call instruction or
// the ID of the callee cannot be determined from this header - confirm in the
// implementation.
using BaseFnCalls =
    std::map<uint32_t,
             std::vector<std::pair<std::string, const opt::Function*>>>;

// Name of the SPIR-V extension this file implements.
constexpr char FNVAR_EXT_NAME[] = "SPV_INTEL_function_variants";
// NOTE(review): presumably the version of the 'targets registry' against which
// the CSV values are interpreted - confirm with the extension spec.
constexpr uint32_t FNVAR_REGISTRY_VERSION = 0;
// Separator between the individual feature values in the 'features' column of
// the --fnvar-targets CSV file.
constexpr char FEAT_SEP = '/';

// One line of the --fnvar-architectures CSV file for a single module.
//
// The fields correspond to the CSV columns of the same name, which in turn are
// the operands of one OpSpecConstantArchitectureINTEL instruction. All fields
// default to zero so that a default-constructed definition never holds
// indeterminate values.
struct FnVarArchDef {
  uint32_t category = 0;      // device category value
  uint32_t family = 0;        // device family value
  uint32_t op = 0;            // opcode of the comparison instruction
  uint32_t architecture = 0;  // device architecture
};

// One line of the --fnvar-targets CSV file for a single module.
//
// The fields correspond to the CSV columns of the same name. `target` defaults
// to zero so that a default-constructed definition never holds an
// indeterminate value; `features` defaults to an empty list.
struct FnVarTargetDef {
  uint32_t target = 0;             // device target ISA value
  std::vector<uint32_t> features;  // feature values for the target
};

// Definition of a variant
//
// Stores architecture and target definitions inferred from lines in the CSV
// files for a single module (as well as a pointer to the Module).
class VariantDef {
 public:
  VariantDef(bool isbase, std::string nm, Module* mod)
      : is_base(isbase), name(nm), module(mod) {}

  bool IsBase() const { return this->is_base; }
  std::string GetName() const { return this->name; }
  Module* GetModule() const { return this->module; }

  void AddArchDef(uint32_t category, uint32_t family, uint32_t op,
                  uint32_t architecture) {
    FnVarArchDef arch_def;
    arch_def.category = category;
    arch_def.family = family;
    arch_def.op = op;
    arch_def.architecture = architecture;
    this->arch_defs.push_back(arch_def);
  }
  const std::vector<FnVarArchDef>& GetArchDefs() const {
    return this->arch_defs;
  }

  void AddTgtDef(uint32_t target, std::vector<uint32_t> features) {
    FnVarTargetDef tgt_def;
    tgt_def.target = target;
    tgt_def.features = features;
    this->tgt_defs.push_back(tgt_def);
  }
  const std::vector<FnVarTargetDef>& GetTgtDefs() const {
    return this->tgt_defs;
  }

  void InferCapabilities() {
    for (const auto& cap_inst : module->capabilities()) {
      capabilities.insert(spv::Capability(cap_inst.GetOperand(0).words[0]));
    }
  }
  const std::set<spv::Capability>& GetCapabilities() const {
    return this->capabilities;
  }

 private:
  bool is_base;
  std::string name;
  Module* module;
  std::vector<FnVarTargetDef> tgt_defs;
  std::vector<FnVarArchDef> arch_defs;
  std::set<spv::Capability> capabilities;
};

// Collection of VariantDef instances
//
// Apart from being a wrapper around a vector of VariantDef instances, it
// defines the main API for generating SPV_INTEL_function_variants instructions
// based on the CSV files.
class VariantDefs {
 public:
  // Returns the text of all error messages accumulated so far.
  std::string GetErr() const { return err_.str(); }

  // Processes CSV files passed to the CLI and populates the collection of
  // variant definitions.
  //
  // Returns true on success, false on error.
  bool ProcessFnVar(const LinkerOptions& options,
                    const std::vector<Module*>& modules);

  // Analyses each variant def module and generates those instructions that are
  // module-specific, i.e., not requiring knowledge from other modules.
  //
  // Returns true on success, false on error.
  bool ProcessVariantDefs();

  // Generates basic instructions required for this extension to work.
  void GenerateHeader(IRContext* linked_context);

  // Generates instructions from this extension that result from combining
  // several variant def modules.
  void CombineVariantInstructions(IRContext* linked_context);

 private:
  // Adds a boolean type to every module if there is none.
  //
  // These are necessary for spec constants.
  void EnsureBoolType();

  // Collects which combinable instructions are defined in which modules.
  void CollectVarInsts();

  // Generates OpSpecConstant<Target/Architecture/Capabilities>INTEL and
  // combines them as necessary. Also converts entry points to conditional ones
  // and decorates module-specific instructions with ConditionalINTEL.
  //
  // Returns true on success, false on error.
  bool GenerateFnVarConstants();

  // Determines which functions in the base module are called by which function
  // variants.
  void CollectBaseFnCalls();

  // Combines OpFunctionCall instructions collected with CollectBaseFnCalls()
  // using conditional copy.
  void CombineBaseFnCalls(IRContext* linked_context);

  // Decorates instructions shared between modules with ConditionalINTEL or
  // generates conditional capabilities and extensions, depending on which
  // variants are used by each.
  void CombineInstructions(IRContext* linked_context);

  // Accumulates all errors encountered during processing; read via GetErr().
  std::stringstream err_;

  // Collection of VariantDef instances
  std::vector<VariantDef> variant_defs_;

  // Used for combining OpFunctionCall instructions
  BaseFnCalls base_fn_calls_;

  // Used for determining which function variant uses which (applicable)
  // instruction
  FnVarUsage fnvar_usage_;
};

}  // namespace spvtools

#endif  // FNVAR_H
